{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "colab_type": "code",
        "id": "6BaMKLQvE8OO",
        "outputId": "d9037787-b981-4975-d670-1ab5df35d47c"
      },
      "outputs": [
        {
          "ename": "ModuleNotFoundError",
          "evalue": "No module named 'torch'",
          "output_type": "error",
          "traceback": [
            "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[1;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
            "Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnn\u001b[39;00m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moptim\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01moptim\u001b[39;00m\n",
            "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import numpy as np\n",
        "# Seed both random number generators used in this notebook: NumPy's global RNG\n",
        "# (np.random.* calls) and PyTorch's RNG, so a fresh run repeats the same draws.\n",
        "np.random.seed(0)\n",
        "torch.manual_seed(0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tu9_J0xrIHhj"
      },
      "outputs": [],
      "source": [
        "N = 100  # number of synthetic samples requested from gen_data\n",
        "nb_epoch = 2000  # number of passes over the data in the training loop\n",
        "batch_size = 20  # samples per optimisation step\n",
        "nb_features = 1024  # passed to Net as hidden_size (width of the shared hidden layer)\n",
        "Q = 1  # input dimensionality: X is drawn with shape (N, Q)\n",
        "nb_output = 2  # total number of outputs; NOTE(review): not referenced by the other cells shown here\n",
        "D1 = 1  # dimensionality of the first output (Y1)\n",
        "D2 = 1  # dimensionality of the second output (Y2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ObaYZBVpyO3Z"
      },
      "source": [
        "# Evaluate on synthetic data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "aoKriQlrIM3A"
      },
      "outputs": [],
      "source": [
        "def gen_data(N, w1=2., b1=8., sigma1=1e1, w2=3., b2=3., sigma2=1e0):\n",
        "    \"\"\"Draw N samples of a two-task linear regression problem with per-task noise.\n",
        "\n",
        "    Inputs X are standard-normal with shape (N, Q). Each target is\n",
        "    Y_i = X * w_i + b_i + sigma_i * eps_i, where eps_i is standard-normal\n",
        "    noise of shape (N, D_i). The defaults are the ground truth used in this\n",
        "    notebook (noise std 10 for the first task, 1 for the second).\n",
        "\n",
        "    Relies on the notebook-level constants Q, D1 and D2 and on NumPy's global RNG.\n",
        "    The sum X * w_i + noise uses broadcasting, so Q must be broadcast-compatible\n",
        "    with D1 and D2 (they are all 1 in this notebook).\n",
        "\n",
        "    Args:\n",
        "        N: number of samples to draw.\n",
        "        w1, b1, sigma1: scalar weight, bias and noise std of the first task.\n",
        "        w2, b2, sigma2: scalar weight, bias and noise std of the second task.\n",
        "\n",
        "    Returns:\n",
        "        Tuple (X, Y1, Y2) of NumPy arrays.\n",
        "    \"\"\"\n",
        "    # The draw order (X, then task-1 noise, then task-2 noise) is kept fixed so that\n",
        "    # a seeded run produces exactly the same data as before.\n",
        "    X = np.random.randn(N, Q)\n",
        "    Y1 = X.dot(w1) + b1 + sigma1 * np.random.randn(N, D1)\n",
        "    Y2 = X.dot(w2) + b2 + sigma2 * np.random.randn(N, D2)\n",
        "    return X, Y1, Y2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 129
        },
        "colab_type": "code",
        "id": "Db4B3XIzIPi-",
        "outputId": "9ce79361-d020-4c16-e123-a5b51872f786"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAM0AAABwCAYAAAC96PTCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHC1JREFUeJztnXlwHNd95z89PZgZgINjcFAEeEgiCTRlSbyVSLZkijQp0o4lsXQx4lo5bKXslHdLe7h2k0pqvavNxrtOstnK1pYrtuNSylo6ii2HK62V2GJ4SKREkeIFiiIbACmSIgESx1wYAtMz0937x2DAAdDdc/UMhmB/qlhFdPd0v5553/d+v9/7vfcEXddxcHDIH9dsF8DB4VbDEY2DQ4E4onFwKBBHNA4OBeKIxsGhQNyz+fChodGKhu4CgTpCobFKPrIklKRKJKbQ6PfirREL/vyt9r52YOc7t7XVC0bHZ1U0lcbtLrzizQaqpvHa3j5O9AwRjCo0N3hZ09XGjk3LEV35GwfZ71uqAKuBfN6hEr/xbSWaW4XX9vax58Mrk3+PRJXJv3du7iroXnYJcDaptne4Nb612wglqXKiZ8jw3ImeYZSkWtD9MgIciSro3BTga3v7bChtZai2d3BEU2VEYgrBqGJ4LjQaJxIzPmdEPJGyVYCzgd2NiB04oqkyGv1emhu8hucC9T4a/cbnjAhF7RPgbGFnI2IXjmiqDG+NyJquNsNza7paLZ14JakyGBqbbH0DDVYC9BYkwNnCzkbELpxAQBWyY9NyIG1+hEbjBOp9rOlqZfsjdzMYGpsRPRpTkux6u5dzl4KERhOTjvK/fG4Ndb4aRgxaaiWp4hYNI6pVRaYRyQ6MZMjViJQLYTaznCs9TtPWVs/Q0GglHwkUH+7NfM5f52H3uxdmRI+eeXQpP9t/gYPdA8QTM237L332Lg6f7ic4mjC8/8Y1HbywdUVJZcxVdjvudzN6NrURMYqe2fkbm43TOKIpI3aFSnft6TFsaRfP9/PpYMz0c80NXkITEScjmvwe/uT3HjQUZLHh3HKGh/MRYiVE45hnZcSO8Rar6NHVIXPBQDoQUF9XQ3QsaXg+Ekvwk7d7OPTRtZLKmI2dY0zT8daIzA/UlXQPO3ACAWXCrlCpVfRIy9FPtwVqTYMKkA4GnLscKrmMGaoxPFwOHNGUCbtCpVbRI1cOP/7B+9r5ymNdLJ7vNzy/4s6AaRlHonGC0XheZcxQjeHhcuCIpkzYFSq1CkEvbDMWg88jsnn9Ir76+L2ILhf/8XfWs3FNB01+DwLQ0uBj8/pF7NzSaVpGgD0ffppXGTNUY3g4w/RwfCk4Pk2ZsDNUahaCzkTPbh73smJJgOe3dFHndSOK6TZRdLl4YesKnts005G+b1kLB070Gz63+3wQJanmXdZqDA+XIzDhiKaMmFX2zPF8EV0udm7u4ukNy2ZUerPjRhg50smkZnp9xqQqxPm2653tohyBCUc0ZcSqsheDWfSo2KiSklSRTQIBAE1+86wBs/BvKe9cjrEiq8DE0xuWFfUcRzQVwKpSz+Y8FyvHHdKBgullytfcKUTIqqaxa08vJ3uGCcesTahCvq98AhPFNDYliUaSpO8Cj0zc5zvAUeDHgAgMAC/Isjw3QiY2U6k5IlaVLOO4G6XZ+DwiO7d0zjhejLljVQZV03j5lQ+nDNIa3bOY78vq/UoJTBQtGkmSNgL3ybL8kCRJLcAJ4J+B/y3L8k8lSfpT4KvA94p9xlymnIOAkK5kP9h9mkOnrppWMivH/eGV7dR5a6YcU5Iqx+VBw+cZmTv5VPRdb/eYZjVk39Pq+8o2BbMpV2CilJ7mHeDIxP/DwDzgUeAbE8feBL6FI5pJMi1urdddFls7m3xFma/jrmoar/5SNs1jMzJ3cpVBSaqc6B02fYeRaHxSDGbf18HuAY7Lg5OJqp9btZDHH1
oyKcpyBCaKFo0syypwY+LPrwFvAVuzzLFBoN3qHoFAXcXn7be11Vf0eQCqqvGjN89w+KMBhsLjNNf7DE0GSFc+0VNDW+u8op8XT6ToPj9ieK77/Ahff7oWn+fmT//S8+uIJ1KEogqBBu+Ucxl+sPv0lHSb6bQ21bLsrpbJz+ZThlRUIRwzFiGkB2/fOX2Nxx9ZSnDU+PuKJ9TJZNWRqMIb714A4Pe231/Q+xVCyYEASZKeJC2ax4DerFM5884rvVLKbGU5T0+4HLEYaQ/U+1ATySnlLDRYMBgaYyg0bnhuODzO+Ysjkz3C9HuPRsaZ/g0pSZVDp65aPnPlspYpn7Uqw9BEGRr9XlpMfA5Ipwm99d5FxscTNNebXzedQ6f6+eKvLZ7xXbnB8P3MMGtgSw0EbAX+CNgmy3JEkqSYJEm1siyPAwsB41Gz2wirsKcR2bZ2tk8wElVo8ntY09nKzi1dlsGCfBzgQhzrXFG2z963YIa5Y1UGAfjlkcvs3NJl6nNk030+yMrlrew7bi3cDKVExvKh6DCNJEmNwJ8BX5ZlOThxeA/w9MT/nwb+qbTi3frkqnCZ7ri53svm9YumVL7sBSUAwrEE+0708/IrH6Jq5oOSblGgzldjeG7FkqYZ9861WIVVekxLg5cXtkozhGaV/qPpsO9EP6/t7WP7I3fj81j3nqHROJvXLWLz+kW0NPhwCenn+jzG1bfcKTul9DQ7gFbg7yVJyhz7beCHkiR9HbgE/G1pxbv1sWpxgcm5Lqs6W6c46FY91KeDMXbt6eWFxyTD86/t7TOMSIkuOPTRNc5dDnEjbjxdwCgQYR2FajM1GXdsWo6q6Rw4cdUwI/tEzzCfX9WBYjCBLptAvY/mBt+MQdPXD5yflZSdUgIB3we+b3BqS/HFmXtYVbhsuvtGUDbezPOKxBRLG/5kzzDPbVw+o3JYiU2d6Jys7puJWE03bYqJQokuF1sfWGxqVoVG46DreD2i4czTDNkiyB40NSrT51Z18PhDS0zvZQdORkAFyLS4+09cxWyi7Eg0zoWrEZYubMRbI9Lo99Lk95hGl8I3FMPKncsczId/PHKZr0zzm4pNj7Fy9m+aUeYTgzau6TAVplGZFnU0lT3Y40wNqACZFtdqZrlLgD//u5P88Q8Os2tPD25RYE1nq+n1zSZ2u5X/kS8HJvwNIzItfUYwuVLuvTUiq0zeY1VnC+NKinjC3D/b+mtLcmZITC9TuXF6mgqRT3gVpg0Abumi72rU0D8xs9vzNQdzcVweshxkLST6Zjb2IGD9vbQ0zO4cHDOcnqZCWEWTjDjRM0xK1dMTyNYuJOD3IggQ8HvZuHahpS+xY9NyNq9fxPxALYIA3prCf+bQqGI50zLf6JuSVDlpMup/sjc9+FnsOm+zhSOaCpKpzC0NPgSgvs44LAw3xxrSdnsnqztbaJznIRRT6O4b5rW9faZhZ9HlYsem5ay/5w4a6mpQps2Z8XlEOlqtxzCsFhMsZC2AfDKNs7+XdDjZNyP8XioJNcHQ2AgJ1TwDIV8c86yCTHdca71uXn7laM4s3Nf29rEva3ZlPsmd0/O+soknVO65M4DocpkmS66VzEPJhaTc5zPQave8o2xUTeXnfb+ge+gMISVMwNvEyrZ7eWr5byC6inuG09MUgF3zzDOOa32dJ6dpUswKL/lkIZzsHeE//Iu1PLqmY4r55vOIbFpnbf4VshZAIcvsFurQxxIx5GAfsYT5UlY/7/sF+68cJKiE0NEJKiH2XznIz/t+kdczjHB6mjwo59yXXOMfxUykyifsHBqNExtL8FtbV7BjUydD4XHQddryqLSFptzbmWmcUBMMj43wo49/wvUbg2houBCYX9fGv1n7Ddqon3Jt99AZw/ucHj7Dk8u24RE9BZfBEU0elHPuSy7TJB/zZnrSZa4sBICGeR5qvemf31sjsshkZRszChGCqqfY8tkWNq5v49JQkM4Fd9BSP/N5CTVBRBml0Vs/ozKrms
rPet/g1PAZIkp0yjkNnWtjg/zhwT9hW9cGtnU8hugSiSijhJSwYfmD8TARZZS2upaC3hsc0eSkXPPMp2M1/9+sVV/V2cLrB84b9oCrOlvZe8w8wTEcS/DyK0cL7jGzBWol9oSaIBgPs//TQ3w0cpZQPAw66AII5320u5fx7x/dicddMymIk4MfEU2OEvA2sqrt/km/Q9VU/vuH/4urMev8Xw2Nt3r2MTaW5NmuJ2j01hPwNhFUZq6D4BU9+D3FJXTeEqKp1nn05c6mzWDWquu6btgDqpqeM58r+3rI3WNamahNDW4iSphGVz2qrvLTnjfoCZ2f2soL6X8CgCfOAGf47v5d/OGmr/Dfjv4V/TcGJi8NKRH2XzlIUkuwc8Uz/LTnjZyCySbb9FrZdi/7rxyccU1cVfh/F97m2a4n8r5vhqoWTaX3WjQSZ7nmmReCkQkH8Mc/OGx4vVmC5CQuFaFGQU+m73Psk0s8Hl9Mva/WtIGaNFFdKkLtGMFkmH8+M8R5DpGovUZICeNxeUhqSTTMR/izGUhdYNfZ16cIJptD/UdAh9PDZ/O6X4Zs0+vLd2/h/YGjKOrM369Yv6aqRVPuefQZVE1j19s9nOgdJhxL0JIlTjvnmZfaY2abcAMjN3JmFwBTBaIJuBfLiIHrCN44qCK6DuNule8cPUKdsohwz1JC0eSUBiql6hzvuY578ce4266CeLMXuyYAE8VQtMJy3vSacU4NGjvqGQ4NHLE8b0Szr4lGbzogEEuOmY7NFOvXVK1oKuVL5LMaSqnRn3L0mLmXjNWmCERXfOipGkR/VjKjW51McYkkI0RcEZLNI+ix+6d8B5vXLSLadIKaBaWl5kxH0ETGBePZndm4cOXdewHc33rvZO9h5dcEvDfFVQhVK5py+BLxRGrGTmK79vTmtRpKKYNvdveYSlI1nX+fwb1Ypqb90uTfgi8O5F7Q3N12DbF5EHWog9Snn+F4zyBqezc1d9grGACXS6De3UAkFbW8zkwwC+rmM7+2jU9HrxJORGj2NfHrS1azreOxyWs8ogfveAe4DIIB4x1zK+Rspy+Raem7z48wFBqfbOm3P3I3J3vMV0MJZs0tKWU3M7t7zBkNyhQTTAR3ArHZfBEMKwQBBLeGq/0Krvowo7FmDl+/nMeKD4WjCSlWtHbywbVjltc1ewPc27KCj4PnGImHaPI0cH/rvTzb9QSiS5wSql64oGXG+grhnqUkG2OIgUEEzzh6ohY1NJ9wZBnKhvzXqs5QtaKx05cwa+nH4inCFkmJjX4P/roadu3pKdq0srPHzAhXJUlDIEkkAu5FvWkTzKOgJ7zoqhtBTCB4jGdmFoLoj0Ft+RY/ERB4cuk2PC4v7/UfRjXpUVa2pQViNo7jET2mfkkkphCKJtGj95C62jWlcQkLiaIslqoVDdgzkmzV0p+7FLIcBFzT2crudz/J27QqV/RN1TR27T3HsYu9jDX2IM6LIixX8GkgZLUdgldh0iu3C1f+vkSh6OgktBS/uWI7T3V+iSvRAX51eR9XRvsnza37W9N5YmAtDjOmfP+aiK7cFEix0c+qFo0diXxWLX04pvDQvQsM1/NaPN/P048u59t/84HhZ7NNKytHP1ePKbhUhsbC+Fx1jI/rU94xPUAY4fvvv8k1XYa7dWqyzaRKDFmVaJZ5XV7TqFpzliPuET0sDdzJNwK/Y5kZUPDzy7DKZlWLJkMpey3maumf39JFrc/NiZ5hgqNxmuZ5Wd3Vys7NnYxE4nmZVrkc/UzPeLz3GuFEiIZ5HlbftRj34rP8l8OvE4yHIFlLcmQ+DZGVrOwKoLd/RE/oPOFEBFxlcSnyQ8fw4QICusU05QwPtq+jL3LRcHByZdt9hqIopkexwu5VNm8J0ZRCrpamzus27c3yzfs63juA4B276YgDuNTJQcM6jwdx8VlU91G8WgIFOIoL7eqE6SMAnnFq2i8Rqx/hsDaOcL169qd8cMED9Ib7CMbD1Oh1qOFmbp
xfjv/uS4jNgySEGwS8jfjctdxIjBFNRmeYVj/teYPTw2cIJ6K0+AJTzhVLvsEZu6ce2L4luiRJfwk8SLqNekmW5aNm11ZqS/Ts6NlweNxyH/rp/HjPx+w/fWGqIGBiktQy/s+Z3bz/6UkEr4KueFCjLaCJiE3DCN44TZ5G5nnqTEe9qx31hp+XH/4WTQ1ufryvm0PHQ1O+B1wqn1sb4IWNK/GIHlPTSkmqDEdHoSZB27ymkswuK3N4wR2Nt9aW6JIkbQA6J3YSuAf4EfCQnc8ohkxL8/WnayeXQ83V0mQmL/XUnsG3MgSpWlLhAP7gatZ1pldI+Vnvm3ww9AEuX/ozgi+ByzdVHJFkhEgyUq5Xs5Up7acO2lg98658nka/F12Dcz3KVMEAaCLnehLoG0QQZ5pW5RjYtTKHX3p+XVH3LAS7zbMvALsBZFk+K0lSQJKkBlmWrUevKoTP487LN0qoCf5O/oeb4wcCUDOOu20c7Y5hXAvWkVAX88HAh+UtcCloFDzFUA03k7ws4fKk0MbrIeXh8+vb8daIXBkctVy03Ww8qxwDu1bjXvFEquB7FordolkAZI9UDU0cMxRNpXcNUFIJ1No4AV8jXrdnxrn+6CBv9rzN2es9jMSN52EktATv9r9PX/RCwblWlSQ1uIRHVi/g8LWjOR12L37ig60kL3YBLjQFar1uvvDgYn73Nz7D3751lvc/MjcvW5tquWtxgF2/lCd3RmhrqmX9PXdwqs948Nho94J8GBi+YbqDQGg0Tiiq0F7mnSHKHQiwDPpUateAjKl1JniW4bHglHniAK/3vsnha8cMM2HNGIhdN40szSZaUkQdXkRDZBW//9AWao6IvHP1fcPrhPAilP4ljCd9M8yuBz8zn6cevpvvvX4q53JQK5e18MPdp6dcNxga5633Lpp+ZvruBfmiJlXTHQQC9T4CDV47fRrD43aLpp90z5Khg/Q2grNKZp54hsw88QwHrr5X1H3LrRldT6e15LxOAz3hQx1tJnXpHtBqWLt+Pj6Pm2c6n0DAxeErp1CIoSk+iLWgXFwBmvlqON3ng4yOJSzXGmiu97JWauNLDy7h5VeMTVWXgOE0hWIHFnNFQ30ed95baRSL3aL5FfCfgb+WJGkt0C/LcuU3hMnCap5499BHeY01mGGHYHQNBBPfI3W9A0QQ64MInjgeoRZR9xJX4+hiHCFVywL3XSzQP4N8SSEcTdEybQxCdIk8Jz3J40u38uq+03z8yQ0io7lH+UOjca4MxkzHqQQB/tUzKzl0eoCXX/nQdPlcs3k9paxpNjnuJQ8RGlUITIi3Utuu2yoaWZbfkyTpmCRJ75F2Rb9p5/2LwXKeuMnxfBFStSRG2nC3XkFw559uouugJ9LhaX9wFQ9sjHDk2jESWrrieV0e6pWlxEKdRGJJGhrc3NNZh1ur5cCx61MSND/RRJatD/BfX7Qeg/iHA5d5/3j+7Veg3sei+X7Tcarmeh/vnOrPuWeMS4CHV7Vz5kLItu37MmR64Xx6Yzux3aeRZfkP7L5nKWkVVvMpmr1N6OiElOJCwq5oO6nLXaSudOJechaxYeRm4mTKg8t3w1BMqaEOUpfvBU1k/fpFPNf1KKkrEicuXiQaS1LrbULqbGf7i0uJjSVo9HtJJFW+/aOJCVnTcqgyKT1W264XsrEUwOrOlsklpoxMoZXLmuk2cfKz0XT40q/fyfNf6LJtyvpcCznbih0LvVnNE1/Zdh+A4bnpdMxbwFhqjIgySrOvieX1XRw40pg+qdWQuriS1PQUfVcS/7IexKYgSWEMjz4PNTgf5dJyWurrJlvc1/b2se/Da4AP8BGMq5OVIHP+2LkhUxMoV7Z0MTsJpCb25TBLQdm4ZiH7T+Set988sVJnKalQ2czFkLOtWDnwhSyIkImSfRw8y9BYcEaKh65r6ehZKr0usa6BoAvg0tETtSz0LOUPHvgKqq5O9ni6JnL63cNTTZeJHsAlwMOr29n6wBKaGz
YjuKZ+LrvFzVUJVE3PaQLlcqrzWdJpOu+eGsA1sSSuUQqKklTzuqfVSp3FkGuqRSiqlL1SV61o7FzoTXSJPNv1BA2BZzl/tX+GmfectJ0v3rmVb7+6j0gsiZ5It4iZXiPin0dqA3hrska7RUxNlw2rO3hh64rsEkz5XHaLa1UJgtG45SS5DLmc6mJ2EtB02Hf8KqJLYOfmrhk9Ra57tjTY57tkkysfMNDgZTSSewp1KVStaMqx0JvXbZ49Oz6uExn2oeObPJbxG8zMHzuyZ60qQaPfYzlJrsnvYf2K+Xk9z6isK5e3cKp3iOCo+aLgVrNLze65ed0imht8ZVluay6GnG3D0oH3FbcgguXzLCqvp0bEb7DCf67s2XyycC0rQWcr3edHjFtVv5f/9NUHqK/Lt7c1LqvoEix7ICt/qZwLl1thd6p/oVStaKwc+OzVRuzCqvLGEyq73/3ENFdquulSaJKiVSUQRePV/9etaMtbMFZl3bFpOaqqceBkf9GDkHY5+fkyW2LNULWigZsO/OnhMwTj4RkOvN1sf2QpB7v7DbezK2QRjEKTFK0qQblbVdHlSvtfgmAYcKjWjZWg8mLNUNWiyTjwTy7bZtv0VytiYwkUk/0f810Eo5TVZ4wqQaVa1Z2bOxFdwqyZPLcSVS2aDHZPfzXDjkUwyrX2c7lb1dk2eW4lnE2dsihkAyIzCtnwqBqp9E7JtyKOaKZR6v6PdgjPobq5JcyzSmKHmTLbIVGH8uKIxoRSfAjHP5jbOKIpI7MVEnUoL45P4+BQII5oHBwKxBGNg0OBOKLJgZJUGQyNoSSrZ5lYh9nFCQSYUOlNch1uHRzRmFCpTXIdbj2cJtOAXEmXjql2e+OIxoB8ki4dbl+KMs8kSXIDfwMsm7jHt2RZPihJ0irge6QXn+yWZfn3bStpBbFzk1yHuUexPc0LwA1Zlh8Gvgb8j4nj/5P0njSfAxolSfqiDWWsOE7SpYMVxQYCXgV+MvH/IaBFkiQPcHfWJk5vApuBfyytiLODk3TpYEZRopFlOQlk9tz+18AuoBXIXgVjEGi3uo/ZTlPlxGwleCNeen4dj/+7/1sHtI9E4wMvPb+uMtsc2Egh7ztXKPc75xSNJEkvAi9OO/xtWZZ/KUnSN4G1wOPAdHumyjahKI43/+LJMeD8bJfDoXrIKRpZln8I/HD6cUmSvkZaLNtlWU5KkjQEZM9JXkh66w0HhzlFUYEASZKWAt8AnpJlOQ6TJts5SZIenrjsKeCfbCmlg0MVUdTuzpIk/Snwm8DlrMOPAcuBvyYtxg9kWf63dhTSwaGasH1LdAeHuY6TEeDgUCCOaBwcCuS2ynI2S/+Z3VKVB0mS/hJ4kHRK00tZg85zEkmSvgs8Qvp3/Y4syz8v17Nut57GLP1nTiFJ0gagU5blh0i/51/NcpHKiiRJG4H7Jt53G+l0rrJxu4nmVSAT0Zs+rjSX+AKwG0CW5bNAQJKkhtktUll5B3h24v9hYJ4kSWVLELytzDOT9J+5yALgWNbfQxPHorNTnPIiy7IK3Jj482vAWxPHysKcFU0B6T+3A3MipSkXkiQ9SVo0j5XzOXNWNPmm/1S8YJWhn3TPkqEDGJilslQESZK2An8EbJNlubg97vPktvJpjNJ/5ii/Ap4BkCRpLdAvy3K5t6KcNSRJagT+DPiyLMvBcj9vzvY0JrxI2vl/S5KkzLHHZFk236n1FkSW5fckSTomSdJ7gAZ8c7bLVGZ2kJ6a8vdZv+tvybJ82fwjxeOk0Tg4FMhtZZ45ONiBIxoHhwJxROPgUCCOaBwcCsQRjYNDgTiicXAoEEc0Dg4F8v8BVsZonU7svPQAAAAASUVORK5CYII=",
            "text/plain": [
              "<Figure size 216x108 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "import pylab\n",
        "%matplotlib inline\n",
        "\n",
        "X, Y1, Y2 = gen_data(N)\n",
        "\n",
        "# Scatter both targets against the shared input. The labels and legend make the\n",
        "# two point clouds distinguishable when the figure is read on its own.\n",
        "pylab.figure(figsize=(3, 1.5))\n",
        "pylab.scatter(X[:, 0], Y1[:, 0], label='Y1')\n",
        "pylab.scatter(X[:, 0], Y2[:, 0], label='Y2')\n",
        "pylab.xlabel('X')\n",
        "pylab.ylabel('Y')\n",
        "pylab.legend(fontsize='x-small')\n",
        "pylab.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_E7LhkkoyR72"
      },
      "source": [
        "# Example Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oFKJ9_mfIUlQ"
      },
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n",
        "    \"\"\"Two-headed network: one shared ReLU hidden layer feeding two linear output heads.\"\"\"\n",
        "\n",
        "    def __init__(self, input_size, hidden_size, output1_size, output2_size):\n",
        "        \"\"\"Create the shared layer (fc1) and the two output heads (fc2, fc3).\"\"\"\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(input_size, hidden_size)  # shared by both outputs\n",
        "        self.relu = nn.ReLU()\n",
        "        self.fc2 = nn.Linear(hidden_size, output1_size)  # head producing the first output\n",
        "        self.fc3 = nn.Linear(hidden_size, output2_size)  # head producing the second output\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the pair (out1, out2), both computed from the same hidden activation.\"\"\"\n",
        "        hidden = self.relu(self.fc1(x))\n",
        "        return self.fc2(hidden), self.fc3(hidden)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "r-UlDkdlK21o"
      },
      "outputs": [],
      "source": [
        "# Q inputs -> hidden layer of width nb_features -> two output heads of size D1 and D2\n",
        "model = Net(Q, nb_features, D1, D2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EyR5ree1yWDk"
      },
      "source": [
        "## Define the task-dependent log-variances"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ih0Ocq-NFuX2"
      },
      "outputs": [],
      "source": [
        "# One log-variance per task, initialised to 0 (i.e. variance exp(0) = 1).\n",
        "# requires_grad=True makes autograd track them so they can be optimised like weights.\n",
        "log_var_a = torch.zeros((1,), requires_grad=True)\n",
        "log_var_b = torch.zeros((1,), requires_grad=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "colab_type": "code",
        "id": "Q_aXEqn84eGR",
        "outputId": "b43872e7-3303-4123-9e28-f5ff9746cae6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[1.0, 1.0]\n"
          ]
        }
      ],
      "source": [
        "# Initial standard deviations before training (ground truth is 10 and 1).\n",
        "# std = sqrt(exp(log_var)), so log_var = 0 gives std = 1 for both tasks.\n",
        "std_1 = torch.exp(log_var_a)**0.5\n",
        "std_2 = torch.exp(log_var_b)**0.5\n",
        "print([std_1.item(), std_2.item()])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mcKDv3htMqXJ"
      },
      "outputs": [],
      "source": [
        "# Everything the optimizer must update: the network weights first,\n",
        "# followed by the two task-dependent log-variances.\n",
        "params = [*model.parameters(), log_var_a, log_var_b]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "HfBN-qs2Q_Wj"
      },
      "outputs": [],
      "source": [
        "# Adam with its default hyperparameters, applied to every tensor in params.\n",
        "# Alternative left by the original author: optim.SGD(params, lr=0.001, momentum=0.9)\n",
        "optimizer = optim.Adam(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MGPCPAXryqT9"
      },
      "source": [
        "## Define the loss criterion"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5rHlOTALGXug"
      },
      "outputs": [],
      "source": [
        "def criterion(y_pred, y_true, log_vars):\n",
        "    \"\"\"Multi-task regression loss weighted by learned per-task log-variances.\n",
        "\n",
        "    For every task i the term\n",
        "        sum over the last dim of exp(-log_vars[i]) * (y_pred[i] - y_true[i])**2 + log_vars[i]\n",
        "    is computed; the terms of all tasks are added and the result is averaged.\n",
        "\n",
        "    Args:\n",
        "        y_pred: sequence of prediction tensors, one per task.\n",
        "        y_true: sequence of target tensors, aligned with y_pred.\n",
        "        log_vars: sequence of log-variance tensors, aligned with y_pred.\n",
        "\n",
        "    Returns:\n",
        "        Scalar tensor holding the mean combined loss.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the sequences are empty or do not all have the same length.\n",
        "    \"\"\"\n",
        "    # Fail loudly instead of silently ignoring surplus targets / log-variances.\n",
        "    if not (len(y_pred) == len(y_true) == len(log_vars)):\n",
        "        raise ValueError(\n",
        "            'y_pred, y_true and log_vars must have the same length, got %d, %d and %d'\n",
        "            % (len(y_pred), len(y_true), len(log_vars))\n",
        "        )\n",
        "    if len(y_pred) == 0:\n",
        "        raise ValueError('criterion needs at least one task')\n",
        "\n",
        "    loss = 0\n",
        "    for pred, true, log_var in zip(y_pred, y_true, log_vars):\n",
        "        precision = torch.exp(-log_var)  # 1 / variance\n",
        "        diff = (pred - true)**2.\n",
        "        # The + log_var term penalises large variances, so the precision cannot\n",
        "        # simply shrink to zero to make the squared error vanish.\n",
        "        loss += torch.sum(precision * diff + log_var, -1)\n",
        "    return torch.mean(loss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9Mk7ma7my397"
      },
      "source": [
        "## Train the network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "_pQ3cDYGVtPv"
      },
      "outputs": [],
      "source": [
        "# Cast the NumPy arrays to float32 (np.random.randn produces float64) so that the\n",
        "# tensors later built with torch.from_numpy have PyTorch's default float32 dtype.\n",
        "# The arrays stay NumPy arrays here; the conversion to torch happens per batch.\n",
        "X = X.astype('float32')\n",
        "Y1 = Y1.astype('float32')\n",
        "Y2 = Y2.astype('float32')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mFm-Jtkm2cGk"
      },
      "outputs": [],
      "source": [
        "def shuffle_data(X, Y1, Y2):\n",
        "    \"\"\"Return X, Y1 and Y2 reordered along axis 0 by one shared random permutation.\n",
        "\n",
        "    The same row order is applied to all three arrays, so inputs and targets stay\n",
        "    aligned. Uses NumPy's global RNG; the inputs themselves are not modified.\n",
        "    \"\"\"\n",
        "    order = np.random.permutation(X.shape[0])\n",
        "    return X[order], Y1[order], Y2[order]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2RJ0U_ss0uT1"
      },
      "outputs": [],
      "source": [
        "## Train Network\n",
        "# Number of full mini-batches per epoch; samples that do not fill a batch are skipped.\n",
        "nb_batches = N // batch_size\n",
        "loss_history = np.zeros(nb_epoch)  # mean batch loss of every epoch\n",
        "\n",
        "for i in range(nb_epoch):\n",
        "\n",
        "    epoch_loss = 0\n",
        "\n",
        "    # New sample order every epoch; X, Y1 and Y2 stay aligned row by row.\n",
        "    X, Y1, Y2 = shuffle_data(X, Y1, Y2)\n",
        "\n",
        "    for j in range(nb_batches):\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        rows = slice(j * batch_size, (j + 1) * batch_size)\n",
        "        inp = torch.from_numpy(X[rows])\n",
        "        target1 = torch.from_numpy(Y1[rows])\n",
        "        target2 = torch.from_numpy(Y2[rows])\n",
        "\n",
        "        out = model(inp)\n",
        "\n",
        "        loss = criterion(out, [target1, target2], [log_var_a, log_var_b])\n",
        "\n",
        "        epoch_loss += loss.item()\n",
        "\n",
        "        loss.backward()\n",
        "\n",
        "        optimizer.step()\n",
        "\n",
        "    # Divide by the number of batches actually run. Dividing by N / batch_size instead\n",
        "    # is only a true mean when N is a multiple of batch_size.\n",
        "    loss_history[i] = epoch_loss / max(nb_batches, 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "colab_type": "code",
        "id": "ZoiJCMsb1BMt",
        "outputId": "270cbd11-8e7f-4426-acfa-3bf789a1fc8f"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7fec2dfdf6a0>]"
            ]
          },
          "execution_count": 15,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl4XPV97/H3mU27Jdka4zVm9ZdgdtthCRQTliQ0S28hTRPCzUKfpA1Jk6akJTe3pJC0uReacHsJT1IKgYQkt2FpiskCAROWAMFLcMCAf17YvGLJlmVLlqWZ0dw/zpE8EiN5NJrRLPq8HvzM6JwzM59nNHzO0e+cOcdLp9OIiEh1CJU6gIiIFI5KXUSkiqjURUSqiEpdRKSKqNRFRKpIpJQv3t6+P+9Db1pb6+nsPFDIOAWhXONTrrmgfLMp1/hUY654vMkbbV7FbqlHIuFSR8hKucanXHNB+WZTrvGZarkqttRFROStVOoiIlVEpS4iUkVU6iIiVUSlLiJSRVTqIiJVJKfj1M3sBuDcYPlvAquAu4AwsAO4wjnXZ2aXA18EBoBbnXO3FyW1iIhkddgtdTM7HzjROXcW8B7g/wDXA7c4584FNgGfMrMG4FrgQmAZ8DdmNr0Yobfu6uZHD77MgE4bLCIyTC7DL08AHwru7wUa8Et7eTDtAfwiPwNY5Zzrcs71Ak8B7yxo2sDjf9jOTx/ewM7d5fctMRGRUjrs8ItzLgX0BD9eCfwSeLdzri+YtguYDcwC2jMeOjh9VK2t9Xl9q6q2NgpAc0s98XjTuB9fbOWYCZQrH+WaTbnGZyrlyvncL2b2QfxSvxjYmDFrtHMQjHpugkH5nvfgYG8CgD17emiIHPZlJlU83kR7+/5Sx3gL5Rq/cs2mXONTjbnGWhnkdPSLmb0b+CrwXudcF9BtZnXB7LnA9uDfrIyHDU4vOC/ocV2KT0RkuFx2lDYDNwLvc87tCSY/Alwa3L8UeBB4FlhqZi1m1og/nv5k4SMz9DeAOl1EZLhchl8+DLQBd5vZ4LSPA7eZ2WeA14EfOOcSZnYN8BCQBq4LtuoLLuSV15CLiEi5yGVH6a3ArVlmXZRl2XuBewuQKyc6pFFEZLiK/EapNtRFRLKr0FL3W10b6iIiw1VmqQe3OvpFRGS4yiz1wS31EucQESk3FVrq/q221EVEhqvIUh+kThcRGa4iS93T4S8iIllVZqkHtxp+EREZrjJLXacJEBHJqkJLffA4dbW6iEimyiz14FaVLiIyXGWW+uDwS2ljiIiUnYosdTT8IiKSVUWWekjjLyIiWVVkqQ8aUKmLiAxTkaV+6MtHanURkUw5XXjazE4E7gducs59x8zuAeLB7OnA74B/Bl4A1gTT251zHypwXkDHqYuIjOawpW5mDcDNwIrBaZllbWbfB247NMstK3DGt/DQ+dRFRLLJZfilD7gE2D5yhvkXLW1xzq0sdLCxHDqkUa0uIpIpl2uUJoFkxkWnM30Bfyt+0CwzuxeYA9zinPvxWM/d2lpPJBIeR1xfY2MtAE1NdcTjTeN+fLGVYyZQrnyUazblGp+plCunMfVszCwGnOOc+2wwaTfwD8CPgGZgpZk96pzbMdpzdHYeyOu1e3r6AOjq6qW9fX9ez1Es8XhT2WUC5cpHuWZTrvGpxlxjrQzyLnXgPGBo2MU5tx+4I/ixw8xWA8cDo5Z6vg6deVfDLyIimSZySONS4A+DP5jZ+Wb27eB+A3AqsGFi8bI7dOrdYjy7iEjlyuXol8XAt4AjgYSZXQb8KTAb2Jyx6JPAx83sGSAMfNM5t63gidE1SkVERpPLjtI1wLIssz4/Yrkk8ImCpDoMXaNURCS7yvxGaXCrThcRGa4yS31o+EWtLiKSqSJLHZ0mQEQkq4osdR3RKCKSXWWWejD8MqBNdRGRYSqz1EsdQESkTFVmqXs6S6OISDYVWu
r+rY5TFxEZriJLfZAqXURkuIos9ZCnUXURkWwqstQH95Tq6BcRkeEqstR13WkRkewqs9SHrlGqVhcRyVSZpT50jVIREclUoaWu49RFRLKpzFIPbjX8IiIyXGWWuoZfRESyyunC02Z2InA/cJNz7jtmdiewGNgdLHKjc+4XZnY58EVgALjVOXd7ETIDGn4REckml2uUNgA3AytGzPqKc+7nI5a7FngH0A+sMrOfOef2FDAvkHlIo1pdRCRTLsMvfcAlwPbDLHcGsMo51+Wc6wWeAt45wXxZafhFRCS7XC48nQSSZjZy1ufM7EvALuBzwCygPWP+LmD2WM/d2lpPJBIeV2CAll09ANTX1xCPN4378cVWjplAufJRrtmUa3ymUq6cxtSzuAvY7Zxba2bXAP8IPD1imcOeoKWz80BeL75vXy8A3T0HaW/fn9dzFEs83lR2mUC58lGu2ZRrfKox11grg7xK3TmXOb6+HPgucC/+1vqgucDv8nn+w/F0jVIRkazyOqTRzO4zs6ODH5cB64BngaVm1mJmjfjj6U8WJOVb6DQBIiLZ5HL0y2LgW8CRQMLMLsM/GuanZnYA6AY+6ZzrDYZiHsLfh3mdc66rGKFD2lIXEckqlx2la/C3xke6L8uy9+IPwxSVF9KFp0VEsqnIb5SGg0H1gQGVuohIpoos9ZC21EVEsqrMUh/aUi9xEBGRMlORpe4FqbWlLiIyXEWWekhj6iIiWVV2qWtLXURkmMos9WBHaVpj6iIiw1RmqQdfPtKWuojIcJVZ6jqkUUQkq8osde0oFRHJqiJLXacJEBHJriJLfWhMXTtKRUSGqdBS16l3RUSyqcxS1/CLiEhWlVnq2lEqIpJVZZb60JZ6iYOIiJSZnK5RamYnAvcDNznnvmNm84E7gCiQAD7mnNtpZgngqYyHXuCcSxU69KEdpWp1EZFMuVzOrgH/8nWZF5v+BnCrc+5uM7sK+BLwd0CXc25ZMYJm8nTuFxGRrHIZfukDLgG2Z0z7LIcuZ9cOzChwrjFpR6mISHa5XKM0CSTNLHNaD4CZhYGrgOuDWbVm9hNgAXCfc+7bYz13a2s9kUh43KETSX9EJxIJE483jfvxxVaOmUC58lGu2ZRrfKZSrpzG1LMJCv0u4FHn3ODQzNXAj4A08ISZPeGcWz3ac3R2HsjrtVPBt476+pK0t+/P6zmKJR5vKrtMoFz5KNdsyjU+1ZhrrJVB3qWOv6N0o3PuusEJzrnvDd43sxXAScCopZ4vHdIoIpJdXqVuZpcD/c65r2VMM+BrwOVAGHgncG8hQo7keR6epzF1EZGRcjn6ZTHwLeBIIGFmlwEzgYNm9liw2EvOuc+a2RZgJTAALHfOrSxKaiAc8rSlLiIyQi47StcAy3J5Mufc3080UK6ikRCJlM7oJSKSqSK/UQoQi4ZJJFXqIiKZKrrU+xMqdRGRTJVb6pHQ0PHqIiLiq9xSj4Y1pi4iMkJFl7qGX0REhqvcUo+ESQ2kdVijiEiGyi31qB+9X+PqIiJDKrbUa2L+icD6NAQjIjKkYku9ubEGgH09/SVOIiJSPiq21Gc01wKwt7uvxElERMpH5Zb6NL/UO/er1EVEBlVsqU9vrgO0pS4ikqliS31wS31vt8bURUQGVW6pD46pa/hFRGRIxZZ6Q12UaCREp4ZfRESGVGype55HW3MtuzoPkNYVkEREgAoudYC58UZ6+1I6AkZEJJDTNUrN7ETgfuAm59x3zGw+cBf+tUh3AFc45/qCa5d+Ef9ydrc6524vUm4A5rU1sBrY2t7D9GDHqYjIVHbYLXUzawBuBlZkTL4euMU5dy6wCfhUsNy1wIX4l7/7GzObXvDEGebGGwDY1tFdzJcREakYuQy/9AGXANszpi0Dlgf3H8Av8jOAVc65LudcL/AU8M7CRX2refFGAN54U6UuIgK5XXg6CSTNLHNyg3NucCB7FzAbmAW0ZywzOH1Ura31RCLhcQXOtGjhTJobY2ze1k
VbWyOe5+X9XIUUjzeVOkJWyjV+5ZpNucZnKuXKaUz9MEZr0sM2bGfngbxfNB5voqOjm2PnNLNmQzvrN7XT1lKX9/MVSjzeRHv7/lLHeAvlGr9yzaZc41ONucZaGeR79Eu3mQ026Fz8oZnt+FvrjJheVAvntwDgtuwt9kuJiJS9fEv9EeDS4P6lwIPAs8BSM2sxs0b88fQnJx5xbIOlvnGrSl1E5LDDL2a2GPgWcCSQMLPLgMuBO83sM8DrwA+ccwkzuwZ4CEgD1znnuoqWPDB/ZiN1NWFeeq2TdDpdNuPqIiKlkMuO0jX4R7uMdFGWZe8F7p14rNyFQh4nH9PGsy+9yRtvdrNgVnnuEBERmQwV/Y3SQYsXxgFYs2FXiZOIiJRWVZT6SUfPIBYJsca1H35hEZEqVhWlXhMLs+io6ezYfYCde/I/TFJEpNJVRakDnHacPwSzer2GYERk6qqaUj99YRuRcIhnXtypU/GKyJRVNaVeXxvltOPa2LH7AC++uqfUcURESqJqSh3g4qXzAfjtCztKnEREpDSqqtSPnjONWdPr+f2GDvb16ILUIjL1VFWpe57HeafOIZka4Ol1O0sdR0Rk0lVVqQOcfeIsPA+efelN7TAVkSmn6kq9qT7Gqce28fqb+9m8bV+p44iITKqqK3WAC5f4O0wfXr2lxElERCZXVZb68W9rYV68kTWunT37DpY6jojIpKnKUvc8j4uWzGMgnebR328rdRwRkUlTlaUOcOaiI2isi/L42m30JVKljiMiMimqttSjkTDLTptLz8Ekjz+nrXURmRryuvC0mV0JXJExaQmwGmgAeoJpfxtcYKNkLl46n0dWb+FXz77BuxbPIxKu2nWYiAiQZ6k7524Hbgcws/OAPwMWAZ90zq0rXLyJaayL8kenzOHXq7bwzLqdnHvKnFJHEhEpqkJsul4LfL0Az1MUFy+dTyQcYvlTr5FMDZQ6johIUXkT+dalmS0FrnLOfcLMHgP2AG3Ay8AXnXO9Yz0+mUylI5Fw3q+fq3+//wWWP/EKf3XpyVxy9lFFfz0RkSLzRpuR1/BLhr8A7gzu/yvwvHNus5l9F7gK+JexHtzZmf9ViuLxJtrb9+e07PmnzOHBZ17j/z20nlOObCUWLd6KZDy5JpNyjV+5ZlOu8anGXPF406jzJjr8sgx4GsA59zPn3OZg+gPASRN87oJpbohxweJ57O3u57G120sdR0SkaPIudTObA3Q75/rNzDOzR8ysJZi9DCibHaYA7z1jAbWxML985jX6+nXcuohUp4lsqc8GdgE459LArcAKM3sCmA/cMvF4hdNYF+XipfPZdyDBit9vLXUcEZGiyHtMPTgG/b0ZP98N3F2IUMVy8dK3sWLNVn71u9dZdupc6msnuktBRKS8TKlv49TXRnjPGW+j52CS5U+9Wuo4IiIFN6VKHeCiJfOZ2VLHI6u3smVXd6njiIgU1JQr9Vg0zEcvWshAOs1dDzkGdHUkEakiU67UAU4+ZgaLLc6mbV089cKOUscRESmYKVnqAB+54DhqomHu+c1munsTpY4jIlIQU7bUp0+r5YPnHEV3b4L/fHzz4R8gIlIBpmypA1y4ZB5z2xp4fO12Xt2hi1SLSOWb0qUeCYe4/KKFpIFbl7/Iwf5kqSOJiEzIlC51gOMXtPKed7yNNzt7+cnDG0sdR0RkQqZ8qQP86XlHs2BWE799YQfPvvRmqeOIiORNpY4/DPOXH1hETTTMDx9aT/veMU8DLyJStlTqgSOm1/OxixfS25fi1gdeJDWgqySJSOVRqWc4+8RZnHHCEWzeto/7f/taqeOIiIybSj2D53lccbHR1lzLL55+DfdGZ6kjiYiMi0p9hPraCJ/5wCI8z+PWB17St01FpKKo1LM4Zm4zf3LuUXTu79P4uohUlLyuEmFmy4B7gBeDSS8ANwB3AWFgB3CFc66vABlL4pIzF7BpWxfPb97NTx/dxEcvXFjqSCIihzWRLfXHnXPLgn
+fB64HbnHOnQtsAj5VkIQlEgp5fOYDi5jT1sAjq7eyYo0ugSci5a+Qwy/LgOXB/QeACwv43CVRVxPhry87mca6KD9+eIO+mCQiZW8ipX6CmS03s9+a2UVAQ8Zwyy78C1NXvJktdXzpw6dQGwtz289f4uXX9pQ6kojIqLx0Hlf+MbO5wDn4F5o+GvgN0Oicmx7MPxb4oXPu7LGeJ5lMpSOR8LhfvxTWbtjFdbf9jmgkzNc/cxa2YHqpI4nI1OWNOiOfUh/JzFYCS4F651yvmZ0HfN45d9lYj2tv35/3i8fjTbS378/34XlZvX4X37v/RWLREF/80CksnN9SFrlyoVzjV67ZlGt8qjFXPN40aqnnNfxiZpeb2dXB/VnAEcAdwKXBIpcCD+bz3OVsyfEz+cwHF5FIDvDtu9fyooZiRKTM5Dumvhw4z8yeBO4H/gr4KvDxYNp04AeFiVhelh4/k6v+20kMDKT513ueZ+2mjlJHEhEZktdx6s65/cD7s8y6aGJxKsOpx7XxhctO4eb7nueW/3yBT39gEUuPn1nqWCIi+kZpvhYdNZ0vffhUopEQ37t/HU+v21HqSCIiKvWJWDi/hav//DRqYxFu+/nL3Hzf8yRTOqWAiJSOSn2Cjp4zja987HSmNcR4bmMHf3PT4ySSqVLHEpEpSqVeAPPijfzDf18CwGs79nHDT55jz76DJU4lIlORSr1AZjTX8m9XL+O80+axefs+rr9zlb59KiKTTqVeQNFIiL+9/HQ+cuFx9BxM8i//sZb7Ht+sU/eKyKRRqReY53lctGQ+13zsdFqn1fCLZ17nW/+xVsMxIjIpVOpFcsycZq79xFJOOWYG69/Yy9e+v5KVL79JIU7LICIyGpV6EU2rj/HXl53MFe82EskBvnf/i3znP1+gc3/FXjtERMpcXt8oldx5nsf5p83lhAWt/ODB9Ty3sYP1b+zlw+86lnNOmk0oNOp5eURExk1b6pPkiOn1XP2R07ji3UY6nebOX63n6z9czes7y+/scSJSuVTqkygUbLV/4y/O4MxFR/D6zv1cd+cq/v2Bl+jq6S91PBGpAhp+KYHp02r59PsXcc5Js/nxwxt45sWdrN3UzvvPPorzT5tLTawyLhwiIuVHW+oldMKR07n+ynfwkQuPI+R53P2bTXz5u0+zYs1WnUNGRPKiLfUSC4dCXLRkPmctmsWDz77BI2u28OOHN/Dwqi388dkLOPOEWUQjWveKSG7UFmWisS7KZcuO4Ya/PJvzT5/L7n0HueOX6/nqv/+OFWu2crA/WeqIIlIBtKVeZqY1xLjiYuOPz1zAgyvf4LHntvHjhzfwsyde4ZyTZ/OuxfOY2VJX6pgiUqbyLnUzuwE4N3iObwIfABYDu4NFbnTO/WLCCaeo6dNq+eiFC7nkzAU89tw2Hlu7nV+v2sLDq7aw+PiZXHD6XBbOb8HzdJy7iBySV6mb2fnAic65s8xsBvAc8CjwFefczwsZcKpraazhT849mvedfSSr1+/iwZVvsHr9Llav38WCWU2cf9pc3vH2mdTG9EeXiOS/pf4EsDK4vxdoAHQcXhFFwiHOXDSLd5xwBBve2MuKNVv5/cZ27vzVeu56yHHyMTO4cPE8bEErIW29i0xZ3kRPMGVmn8YfhkkBs4AYsAv4nHOuY6zHJpOpdCSidUG+2jt7eWTl6/zi6Vfp6va/vNTSVMNZJ83mkrOP4sjZ00qcUESKZNQttwmVupl9EPgfwMXAEmC3c26tmV0DzHPOfW6sx7e378/7xePxJtrby+8r9qXIlU6n2bxtH4//YRtrN3bQc9A/UuaI1jpOObaNU45t4+zT5tG5p2dSc+WiXH+PUL7ZlGt8qjFXPN40aqlPZEfpu4GvAu9xznUBKzJmLwe+m+9zy/h4nsex85o5dl4zydQAz23sYOVLb7Lu1T38etUWfr1qCw3/tY5FR7Zy6rFtnHj0DBrroqWOLSJFkO+O0mbgRuBC59yeYNp9wJedc6
8Ay4B1hQopuYuEQyw9fiZLj59JIjmAe6OTtZs6eOHVPax8eRcrX95FKFgJnHpsG6ccO4NZ0+t1FI1Ilch3S/3DQBtwt5kNTrsD+KmZHQC6gU9OPJ5MRDQS4sSjZ3Di0TNoa2tk7Us7Wbupgz9s6mDjlr1s2LKXu3+ziZmtdZxyTBsL5zdz7LwWmhtipY4uInnKq9Sdc7cCt2aZ9YOJxZFi8TyPeTMbmTezkfedfST7evp5fvNu/rCpg3Wv7eHh1Vt4ePUWAGa21HHsvGaOm+eX/OwZ9TqiRqRC6ODmKWpaQ4xzTp7NOSfPJpFM8cr2fWzc2sXGrV1s2tbF0+t28vS6nQA01EY4Zm5Q8nObOWr2NGJRHbUkUo5U6kI0Esbe1oq9rRWAgXSa7R09bNraxcate9m4tYvnN+/m+c3+l4XDIY8jZzX5O2fntnD0nGm0NMY0Li9SBlTq8hYhz2NevJF58UaWnTYXgL3dfWza2sWGrXvZtLWLV3fsZ/P2fTyEP2RTXxNhTlsDc9oaaGuu5fgFrcxta6CuRh8xkcmk/+MkJy2NNSw5fiZLjp8JQF9/ild27GPT1r28saub7R09vLJ9H5u2dQ173PRpNcxta2RuUPhz4w3MmdGgC4GIFIlKXfJSEwvz9gWtvH1B69C0RHKA7R09vLZzHzt2H2BbezfbOnp44ZXdvPDK7mGPb26McURLHfNmTaOxJky8pY625lriLXW0NNbogtwieVKpS8FEIyEWzGpiwaymYdN7DibY1t7Dto4etrf3sH13D7s6e9m4tYsNW7ve8jzhkEdbcy1tLXXEg9vBwm9rrqWxLqrxe5FRqNSl6Bpqoyyc38LC+S3DpidTA6QjYTa8spv2rl469h6ko6uX9uD2xVf3ZH2+2liYtuY64i21tDXX0dZSS3NDjJbGGqY31dDcWKOrRcmUpVKXkomEQ8TbGomOcv6hg/1JOvYeHCr8YcXf1cvW9u5Rn7uhNkJjfYymuiiNdVGa6qM01kdpqovRWBfcr48G82PU1YS19S9VQaUuZas2Fhn6wtRI6XSa7t4EHV0H6eg6SFd3H53dfezd38fe7n66evrp7k3Q3tnLQA4nrQuHvEPlXxelsT5GfHo9EXjLCqEpWCFEdYZRKUMqdalInufRVB+jqT7GUWOcYnggnaa3L0n3gQT7exP+7QG/8Ad/7u71p+3vTbBnXx9b23M7m2VNNExNLExDbYTaWISG2gg10TCxaIim+hj1NRFqYmFikRDRiD89FgkTjYaCaSEaaqO0aLhICkilLlUt5Hk01EZpqI1yRI6PSaYG6OlNEK2N8ca2vUMrgP0H+g+tBIKf+/pT7D+QoH1vL8lUfmeS9oBYLExtsJKIRcLUREPEov4KIRwOEfKgriZCLBqmtbmOZCJJTTRMJOyvHKLBbWTo1iMaCQe3/vxIJGOZcEhHGFUplbrICJFwiObGGuLxJuojuRVfOp2mPzlAfyJFX3+KfQcSHOxP0tef8qcnU/QnBkgE9xPJAbp6+tnddZBkaoC+/hR9iRQHg5VEfyJFamBiF7A5nHDIIxIJEfb82wMHk0QjIVoaY4Q8b+i7BLGIv0KIhEKEwx7hjJVBXU2EUMgj7Hk0NtbQdzDh/xzyCHkeXsgj5EE45K+YQiHv0D/P/xcOeXihQ8sMjpbFoqFhy4WGnhNI+3+FRcIhwiGPdNo/pDYW9VeCkZA3NOyW9ELs3dtL98EEtbEI4ZC/oksmB8Dzh/lSA2lIpwmHQ3je4BUoPP++Bx7+RA//r8TB3S/+fA+C149EPLwg78jPR5pDucNFXKGq1EUKwPM8fzgmGqapHtpa6ib8nKmBgaEVQWog7a8wEinqG2p5s2M//f0pEil/fnLoNk0imSKRSpNMDpBIZc7zbxOpgaF5g8+dTKWpiYboSwyw/0CCdDrNwf5UkKO4K5dq5QGjvXOxSIi/u2IJRx/x1v1FE6VSFylT4VCIupoQdTXDp8fjTbQ1Tt5FTpIpv/
hTqTTJgQFSwTBTauDQSmFgIE1zcz0de7pJpdKkBtKk02kG0mkGBghug3/B/VRwP51m6DmGT/fnDT0m47kGt5gHs4VDHpFwiN7+JOngOQaHl6LRCD0H+tm+u4d58QbSadjX08+0hhjpdJq+xMDQsqnUAHDor4XBLex0GtKkCf4byja4DPh/4SWDx/sb/n5OPC/I6y/veR7pdJppDSN+sQWiUheRMUXCISJh4DDrkXi8ieba8jsiqBovZzcW7XIXEakiBd9SN7ObgDPx/0r5gnNuVaFfQ0REsivolrqZnQcc55w7C7gS+L+FfH4RERlboYdfLgD+C8A59zLQamajfzNEREQKqtDDL7OANRk/twfT9mVbuLW1nsgEvmodjzcdfqESUK7xKddcUL7ZlGt8plKuYh/9MuYR9p2dB/J+4qm2R3uilGv8yjWbco1PNeYaa2VQ6OGX7fhb5oPmADsK/BoiIjKKQpf6r4HLAMzsdGC7c678VpEiIlXKG/w2VKGY2f8C/ggYAK5yzv2hoC8gIiKjKnipi4hI6egbpSIiVUSlLiJSRVTqIiJVRKUuIlJFVOoiIlVEpS4iUkUq8iIZpT69r5ndAJyL//59E/gAsBjYHSxyo3PuF2Z2OfBF/GP2b3XO3V7ETMuAe4AXg0kvADcAdwFh/G/2XuGc65vMXEG2K4ErMiYtAVYDDUBPMO1vnXNrzOzLwIfwf7fXOed+WYQ8JwL3Azc5575jZvPJ8X0ysyhwJ7AASAGfdM69UsRcd+BfniIBfMw5t9PMEsBTGQ+9AH8DbbJy3UmOn/dJfr/uAeLB7OnA74B/xv9/YfCcVO3OuQ+ZWTPwE6AZ6AY+6pzbU6BcI/thFZP4+aq4Us88va+ZvR34PnDWJL7++cCJwevPAJ4DHgW+4pz7ecZyDcC1wDuAfmCVmf2sUB+cUTzunLssI8MdwC3OuXvM7J+BT5nZDyc7V7DSuD3IdB7wZ8Ai/A/suoy8RwF/jv/7bAaeNLOHnHOpQmUJfi83AysyJl9Pju8T8H5gr3PucjO7GP9/2g8XKdc38P9nv9vMrgK+BPwd0OWcWzbi8R+bxFyQ4+edSXy/nHMfypj/feC2Q7OGv1/4ZfqYc+5GM/s08PfBv4nmytYPK5jEz1fgM7sdAAAEFUlEQVQlDr+U+vS+T+BvSQLsxd/azHaqyTOAVc65LudcL/6W1TsnJ+KQZcDy4P4DwIVlkOta4OujzDsf+JVzrt851w68DpxQ4NfvAy7BP0/RoGXk/j5dAPwsWPYRCvfeZcv1WeC+4H47MGOMx09mrmzK4f0CwMwMaHHOrRzj8Zm5Bn/nhZCtH5YxiZ+vSiz1Wfgf8EGDp/edFM65lHNucMjgSuCX+H8mfc7MHjWz/zCztiw5dwGzixzvBDNbbma/NbOLgAbnXN+I1y9FLgDMbCmwxTm3M5h0vZk9YWb/ZmZ1k5HNOZcM/ifKNJ73aWi6c24ASJtZrBi5nHM9zrmUmYWBq/CHCwBqzewnZvaUmX0pmDZpuQK5ft4nOxfAF/C34gfNMrN7zezpYMiDEXkL9jkbpR8m9fNViaU+0pin9y0WM/sg/i/tc/jjZdc4594FrAX+MctDip1zI3Ad8EHg4/jDHZnDa6O9/mS+f3+BP14I8K/Al51zQ+cJyrJ8KX63432fipoxKPS7gEedc4NDDVcDnwYuBi43syWTnGsin/div18x4Bzn3G+CSbuBfwA+gr/v6+tmNrLAC55pRD/k8loFe78qsdRLfnpfM3s38FXgvcGfTyucc2uD2cuBk7LknMvh/4TNm3Num3Pup865tHNuM7ATf2iqbsTrT2quEZYBTwd5fxbkBP9P0kl/zzJ0j+N9Gpoe7NTynHP9Rcx2B7DROXfd4ATn3Pecc93BFuEKRrx3xc41zs/7ZL9f5wFDwy7Ouf3OuTuccwnnXAf+DvrjR+Qt6OdsZD8wyZ+vSiz1kp7eN9hrfiPwvsGdi2Z2n5kdHSyyDFgHPA
ssNbMWM2vEHxt7soi5Ljezq4P7s4Aj8Avh0mCRS4EHJztXRr45QLdzrt/MPDN7xMxagtnL8N+zR4E/NrNYsPxc4KViZ8Mfu8z1ffo1h8ZM3w/8hiIJhgr6nXNfy5hmwdCLZ2aRINeLk5xrPJ/3ScsVWAoMnRnWzM43s28H9xuAU4ENI3IN/s4nLFs/MMmfr4o8S6OV8PS+wZ7yf8T/YAy6A//PrAP4h0d90jm3y8wuA76Mf3jezc65HxcxVxP+mGsLEMMfinkO+CFQi7/T8ZPOucRk5srItxj4hnPuvcHPf4Z/tEEPsA240jl3wMw+D1weZPufGUMOhczxLeBI/MMEtwWvdyc5vE/BcMhtwHH4O+s+4ZzbUqRcM4GDHLoc5EvOuc+a2f8G3oX/+V/unPunSc51M3ANOXzeJznXn+J/7n/rnPtpsFwkeH3DP6Dhu865O4Ii/RH+zue9+IeLdhUgV7Z++HiQYVI+XxVZ6iIikl0lDr+IiMgoVOoiIlVEpS4iUkVU6iIiVUSlLiJSRVTqIiJVRKUuIlJF/j9M7v0+O93/0QAAAABJRU5ErkJggg==",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Plot the loss history: one value per epoch, as stored in loss_history.\n",
        "\n",
        "pylab.plot(loss_history)\n",
        "pylab.xlabel('epoch')\n",
        "pylab.ylabel('training loss')\n",
        "pylab.title('Loss history')\n",
        "# show() renders the figure without echoing the list of Line2D objects as cell output\n",
        "pylab.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "colab_type": "code",
        "id": "tudLOfJazx4X",
        "outputId": "9c9cbd93-433a-48ff-e049-22ccb0cdd861"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[8.587499618530273, 0.9198104739189148]\n"
          ]
        }
      ],
      "source": [
        "# Standard deviations found by training (ground truth is 10 and 1).\n",
        "# std = sqrt(exp(log_var)) for each task's learned log-variance.\n",
        "std_1 = torch.exp(log_var_a)**0.5\n",
        "std_2 = torch.exp(log_var_b)**0.5\n",
        "print([std_1.item(), std_2.item()])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "multi-task-learning-example-pytorch.ipynb",
      "provenance": [],
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
